import yaml
import argparse
import os
import numpy as np
import pandas as pd
import pickle
import sys
from glob import glob
import multiprocessing as mp
import logging
from scipy import sparse
import math

sys.path.insert(0, '/REDACTED/fairness/code/rf/utils')
from data import LoadedInterface, csv_to_h5, hash_file

sys.path.insert(1,'/REDACTED/fairness/code/utilities/')
from costCalculatorV2 import *

# Absolute path of the YAML data config; it is also passed to LoadedInterface.
cfg = '/REDACTED/fairness/code/rf/config/data-config.yaml'
# Switch to the config directory so the relative open() below resolves.
os.chdir('/REDACTED/fairness/code/rf/config')
# Parse the config inside a context manager so the file handle is always
# closed (the previous bare open() leaked it for the life of the process).
with open('data-config.yaml', 'r') as stream:
    out = yaml.safe_load(stream)
print('config file loaded')

# load and clean data

data_interface = LoadedInterface(cfg, rows=None, query=None)
data = data_interface.get(query=None, as_array=False)
inds, features, target, meta, reward, sample_weight, fit_weight = data

# report the dimensionality of every loaded piece (diagnostic output only)

for item in [features, target, meta, reward, sample_weight, fit_weight]:
    print(item.ndim)
    print(item.shape)

# turn the pieces that get merged into 2D dataframes
# NOTE: the earlier loops did `item = ...` on the loop variable, which only
# rebinds the local name and never changed features/meta/reward/sample_weight.
# Rebinding the real names makes the conversion take effect; pd.DataFrame()
# turns a 1-D Series/array into a one-column frame and leaves frames as frames.

features, meta, reward, sample_weight = (
    pd.DataFrame(piece) for piece in (features, meta, reward, sample_weight)
)

# merge everything onto features by row index (inner join, pandas default)

for item in [meta, reward, sample_weight]:
    features = features.merge(item, left_index=True, right_index=True)

output = features

# bit of data cleaning

# keep study years 2010-2014; .copy() gives an independent frame so the column
# assignments further down do not write into a slice of the merged data
output = output.loc[(output.study_year>2009) & (output.study_year<2015)].copy()

# Columns present on both sides of a merge come back with pandas' default
# '_x' (left) / '_y' (right) suffixes. Match on the *suffix* only: the old
# substring test (`'_x' in x`) would chop two characters off any column that
# merely contained '_x', and drop any column that merely contained '_y'
# (which is why 'study_year' had to be special-cased).

# restore the original name of the left-hand copy
output = output.rename(
    columns={c: c[:-2] for c in output.columns
             if isinstance(c, str) and c.endswith('_x')}
)

# discard the right-hand duplicate
output = output.drop(
    columns=[c for c in output.columns
             if isinstance(c, str) and c.endswith('_y')]
)

# express the outcome in present-value terms: multiply each row's
# chg_in_tax_owed by the config's inflation factor for its study_year

inflation = out['inflation-dict']
output['chg_in_tax_owed_pv'] = [
    inflation[yr] * amt
    for yr, amt in zip(output['study_year'], output['chg_in_tax_owed'])
]

# winsorize chg_in_tax_owed_pv at its own 1st and 99th percentiles
pv_floor = output['chg_in_tax_owed_pv'].quantile(.01)
pv_ceiling = output['chg_in_tax_owed_pv'].quantile(.99)
output['chg_in_tax_owed_pv'] = output['chg_in_tax_owed_pv'].clip(lower=pv_floor, upper=pv_ceiling)

# non-black probability is the complement of predicted_prob_black

output['predicted_prob_nonblack'] = 1 - output['predicted_prob_black']

# attach the EIC change indicators (left join, so no output row is lost)
# NOTE(review): duplicates are removed on taxpayer_id alone (first row wins)
# even though the join uses three keys -- confirm that is intended.

chg_inds = (
    pd.read_csv('/REDACTED/data/raw/research_audits_eic_chg_inds.csv')
    .drop_duplicates(subset='taxpayer_id', keep='first')
)
print(len(output))  # row count before the merge
output = pd.merge(
    output,
    chg_inds,
    how='left',
    on=['taxpayer_id', 'tax_period', 'study_year'],
)

print(len(output))  # row count after the merge

## EITC adjustment variables:
## disallowance flag, "reduced by more than 20%" flag, and amount reduced (for regressor)
eitc_claimed = output['eitc_amt']
eitc_corrected = output['eitc_amt_cor']
eitc_drop = eitc_claimed - eitc_corrected

output['eitc_dis'] = output['eitc_amt_to_zero']
# 1 when the corrected amount is lower and the drop exceeds 20% of the claim
output['eitc_red_20'] = np.where(
    (eitc_claimed > eitc_corrected) & (eitc_drop > 0.2*eitc_claimed), 1, 0
)
# signed difference: it is NOT floored at zero, so increases show as negatives
output['eitc_amt_red'] = eitc_drop

## part of the tax change that is not explained by the EITC reduction
output['non_eitc_amt_red'] = output['chg_in_tax_owed'] - output['eitc_amt_red']

## total adjustment of refundable credit:
## ref_cred_amt_dif is not created in this section; it has to be a column of
## output already, because it is inflation-adjusted below.

## net revenue = change in tax owed minus the audit cost for the activity code
output = getCostsACOnly(output,acvarb='activity_code',median=False,wins=True)
output = output.rename(columns={"exp_cost": "aud_cost_ac"})
output['net_revenue'] = output['chg_in_tax_owed'] - output['aud_cost_ac']

# present-value version of each outcome: scale by the study_year inflation factor
inflation = out['inflation-dict']
for base_col in ('eitc_amt_red', 'non_eitc_amt_red', 'ref_cred_amt_dif', 'net_revenue'):
    output[base_col + '_pv'] = [
        inflation[yr] * amt
        for yr, amt in zip(output['study_year'], output[base_col])
    ]


## Added 5/9/2023: remove the dep_database columns that were pre-processed incorrectly
cols_to_drop = ['REDACTED, AVAILABLE UPON REQUEST WITH APPROPRIATE CLEARANCES FROM IRS']
output = output.drop(columns=cols_to_drop)

### added 06/06/2023: bring in the addchild and ref_ed variables
# that file is joined on the crosswalk alias, so taxpayer_id is mapped to its
# alias (stored as taxpayer_id_new) first

crosswalk = pd.read_csv('/REDACTED/data/raw/research_audits_taxpayer_id_crosswalk_final.csv')
new_varbs = pd.read_csv('/REDACTED/data/raw/research_audits_addchild_refed.csv')
new_varbs.columns = [name.lower() for name in new_varbs.columns]

output['taxpayer_id'] = output['taxpayer_id'].astype(np.int64)

# show the join-key dtype on both sides
print(output.taxpayer_id.dtypes)
print(crosswalk.taxpayer_id.dtypes)

output = pd.merge(output, crosswalk, how='left', on='taxpayer_id')

# number of rows without a crosswalk match
print(output.alias.isnull().sum())

output = output.rename(columns={'alias': 'taxpayer_id_new'})

print(output.taxpayer_id_new.isnull().sum())

# rows missing from the crosswalk get taxpayer_id_new = 0
output['taxpayer_id_new'] = output['taxpayer_id_new'].fillna(0)

print(output.taxpayer_id_new.isnull().sum())

new_varbs = new_varbs.rename(columns={'taxpayer_id': 'taxpayer_id_new'})

output = pd.merge(output, new_varbs, how='left', on=['taxpayer_id_new', 'study_year'])

print(output.taxpayer_id_new.isnull().sum())

# ref_cred_dis = 1 only when ACTC, EITC and REC were all claimed (> 0) and all
# three corrected amounts are exactly zero
all_claimed = (output.actc_amt > 0) & (output.eitc_amt > 0) & (output.rec_amt > 0)
all_zeroed = (output.actc_amt_cor == 0) & (output.eitc_amt_cor == 0) & (output.rec_amt_cor == 0)
output['ref_cred_dis'] = np.where(all_claimed & all_zeroed, 1, 0)
print(output.ref_cred_dis.value_counts())
print(output.ref_cred_dis.isnull().sum())

# write out the intermediate dataset (before dep_database features are added)
print(output.head())
print(len(output))

output.to_csv('/REDACTED/fairness/code/rf/data/clean_rf_data.csv', index=False)
print("CSV file written")

### import and add on dep_database features

dep_database = pd.read_csv('/REDACTED/data/raw/receivedData/dep_database/dep_database_research_audits_masked.csv')
dep_database.columns=dep_database.columns.str.lower()

dep_database = dep_database[['REDACTED, AVAILABLE UPON REQUEST WITH APPROPRIATE CLEARANCES FROM IRS']]

# columns that identify one dep_database record
dep_keys = ['tax_yr', 'taxpayer_id', 'taxpayer_id_typ']

# Drop records with a missing key BEFORE zero-filling. Previously fillna(0)
# ran first, so the dropna could never remove anything and key-less rows were
# kept with a key of 0.
dep_database = dep_database.dropna(subset=dep_keys, how='any')
dep_database = dep_database.fillna(0)
# Re-number the rows so the labels returned by idxmax() below are valid for
# .loc even though dropna may have left gaps in the index.
dep_database = dep_database.reset_index(drop=True)

# clean up rule indicators: any non-zero value becomes 1

rule_vars = [x for x in dep_database.columns.tolist() if 'rule' in x]
for var in rule_vars:
    dep_database[var] = (dep_database[var] != 0).astype('int64')

# clean up count_issues: the code 'D' is treated as 0, then cast to float

dep_database['count_issues'] = (
    dep_database['count_issues']
    .mask(dep_database['count_issues'] == 'D', 0)
    .astype(float)
)

# dedupe records: keep, per key, the row with the max irs_dep_risk_score

best_rows = dep_database.groupby(dep_keys)['irs_dep_risk_score'].idxmax()
dep_database = dep_database.loc[best_rows]
# explicit check instead of assert, which is stripped under `python -O`
if dep_database.duplicated(subset=dep_keys).any():
    raise ValueError('dep_database still has duplicate tax_yr/taxpayer_id/taxpayer_id_typ records')
print('dep_database is unique')

# merge dep_database to output on every column the two frames share

overlap = [col for col in output.columns if col in dep_database.columns]
print(overlap) # should just be taxpayer_id and study_year (just getting taxpayer_id on 8/8/24)

# put the join keys into one canonical form on both sides:
# float -> floor to a whole number -> string (e.g. 12.0 becomes '12')
for frame in (output, dep_database):
    for key in overlap:
        frame[key] = frame[key].astype(float).map(math.floor).astype(str)

# validate='1:1' makes pandas raise if the keys are not unique on both sides
output = pd.merge(output, dep_database, how='left', on=overlap, validate='1:1')

# observations without an irs_dep_risk_score get a very low score (-15000);
# issue count and rule flags for those same observations become 0
# NOTE(review): the trailing fillna(0) fills every remaining NaN in the whole
# frame, not only the dep_database columns -- confirm that is intended.
output = output.fillna({'irs_dep_risk_score': -15000}).fillna(0)

# attach the dependents (deps) variables, joined on the crosswalk id and study_year
deps = pd.read_csv('/REDACTED/data/raw/research_audits_deps_varbs.csv')

output['taxpayer_id_new'] = output['taxpayer_id_new'].astype(np.int64)
# show the join-key dtype on both sides
print(output.taxpayer_id_new.dtypes)
print(deps.taxpayer_id.dtypes)

deps = deps.rename(columns={'taxpayer_id': 'taxpayer_id_new'})
output_final = pd.merge(output, deps, how='left', on=['taxpayer_id_new', 'study_year'])

print(output_final.taxpayer_id_new.isnull().sum())

# dep_reduced_ind is 1 when the corrected home code is below the reported one.
# Otherwise it is 0*reported*corrected, i.e. 0 when both values are present
# and NaN when either one is NaN, so missing data stays missing.
home_cd = output_final['child_exemption_HOME_CD']
home_cd_cor = output_final['child_exemption_HOME_CD_COR']
output_final['dep_reduced_ind'] = np.where(home_cd_cor < home_cd, 1, 0*home_cd*home_cd_cor)

print(output_final.dep_reduced_ind.value_counts())
print(output_final.dep_reduced_ind.isnull().sum())

# final diagnostics before writing
print(output_final.head())

print(output_final.exclude_cond.value_counts())
print(len(output_final))

# write out data - 17 taxpayers have a taxpayer_id_new of 0, although taxpayer_id works as expected
output_final.to_csv('/REDACTED/fairness/code/rf/data/clean_rf_data_plus_dep_database.csv', index=False)